import argparse
import os
import sys
import time
import urllib.request

import joblib
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms as trn

# Command-line interface. Parsed at import time; the resulting `args` is read
# by the __main__ section below.
parser = argparse.ArgumentParser(
    description='Classify images into Places365 scene categories with a pretrained CNN')
parser.add_argument('--input_txt', default='', type=str,
                    help='text file listing one image path per line')
parser.add_argument('--save_folder', default='./testout/', type=str,
                    help='directory to save the pickled predictions and image paths')
parser.add_argument('--arch', default='resnet18', type=str,
                    help='torchvision architecture name; weights are read from '
                         '<arch>_places365.pth.tar')
parser.add_argument('--cpu', action="store_true", default=False, help='Use cpu inference')
parser.add_argument('--split', default=0, type=int,
                    help='integer prefixed to the output file names')
args = parser.parse_args()


# Load the Places365 class labels, downloading the category list if missing.
file_name = 'categories_places365.txt'
if not os.path.exists(file_name):
    synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt'
    # Writes to exactly `file_name` and raises on failure, instead of a
    # shelled-out wget whose exit status was ignored.
    urllib.request.urlretrieve(synset_url, file_name)
# Keep the first space-separated token of each line and drop its first three
# characters (lines are expected to look like "/a/airfield 0").
with open(file_name, encoding='utf-8') as class_file:
    classes = tuple(line.strip().split(' ')[0][3:] for line in class_file)

class dataset(torch.utils.data.Dataset):
    """Image dataset backed by a text file listing one image path per line.

    Each item is ``(transforms(rgb_image), path)``, where ``rgb_image`` is the
    file at ``path`` opened with PIL and converted to RGB.
    """

    def __init__(self, list, transforms):
        """Read the image list.

        Args:
            list: path to a text file with one image path per line.
                Surrounding whitespace is stripped and blank lines are
                skipped. (The name shadows the builtin only inside this
                method; it is kept for keyword-argument compatibility.)
            transforms: callable applied to each RGB PIL image.
        """
        with open(list, 'r', encoding='utf-8') as list_file:
            stripped = (line.strip() for line in list_file)
            self.image_list = [path for path in stripped if path]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(transforms(image), path)`` for the image at ``index``."""
        path = self.image_list[index]
        # convert() returns a new, fully loaded image, so the source file
        # handle can be closed by the context manager right away.
        with Image.open(path) as img:
            data = img.convert('RGB')
        return self.transforms(data), path

    def __len__(self):
        """Number of image paths in the list file."""
        return len(self.image_list)

if __name__ == '__main__':
    # Inference only: turn autograd off for the rest of the script.
    torch.set_grad_enabled(False)

    # Locate the pretrained Places365 weights for the requested architecture.
    # The directory can be overridden with $PLACES365_MODEL_DIR.
    arch = args.arch
    model_dir = os.environ.get('PLACES365_MODEL_DIR', '/scratch/shared/beegfs/yuki/models/')
    model_file = '%s_places365.pth.tar' % arch
    model_path = os.path.join(model_dir, model_file)
    if not os.path.exists(model_path):
        # Published at places2.csail.mit.edu/models_places365/<arch>_places365.pth.tar
        print('please download manually.', model_path, file=sys.stderr)
        sys.exit(1)

    # Create the output directory up front so a bad --save_folder fails
    # before, not after, the whole inference pass.
    os.makedirs(args.save_folder, exist_ok=True)

    model = models.__dict__[arch](num_classes=365)
    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Strip the 'module.' key prefix (presumably left by a DataParallel-wrapped
    # model at save time) so keys match the bare torchvision model.
    state_dict = {k.replace('module.', ''): v
                  for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    model.eval()

    BS = 8
    print('Finished loading model!')
    print(model)
    device = torch.device("cpu" if args.cpu else "cuda")
    model = model.to(device)
    # Resize to a fixed 256x256 (aspect ratio not preserved), center-crop 224,
    # then normalize with the ImageNet channel mean/std.
    preprocess = trn.Compose([
        trn.Resize((256, 256)),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_loader = torch.utils.data.DataLoader(
        dataset(args.input_txt, preprocess),
        batch_size=BS, shuffle=False, num_workers=4,
        pin_memory=not args.cpu,
    )
    num_images = len(test_loader.dataset)

    # Inference loop: collect the top-1 class index and source path per image.
    amaxes = []
    paths = []
    for i, (imgs, image_paths) in enumerate(test_loader):
        start = time.perf_counter()
        imgs = imgs.to(device)
        out = model(imgs)
        amaxes.extend(torch.argmax(out, dim=-1).cpu().tolist())
        paths.extend(image_paths)
        if i % 10 == 0:
            # Covers device copy + forward + argmax readback for this batch
            # (DataLoader fetch time is excluded). Measured once so the time
            # and rate agree; the rate uses the real batch size because the
            # final batch may be smaller than BS.
            elapsed = time.perf_counter() - start
            print(f"im_detect: {i*BS + 1:5}/{num_images} Time: {elapsed:.3f}s",
                  f"== {imgs.size(0)/elapsed:.1f}Hz", flush=True)

    # Map class indices to Places365 category names and save alongside paths.
    labels = [classes[c] for c in amaxes]
    joblib.dump(labels, os.path.join(args.save_folder, f'{args.split}-places365_dets.pkl'))
    joblib.dump(paths, os.path.join(args.save_folder, f'{args.split}-places365_paths.pkl'))

